{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "f86ede39-8d9f-4da9-bc12-955f2fddd484",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "## Copyright 2023 Google LLC"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3f3cbf47-a52b-48b1-9bd3-3435f92f2174",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# Copyright 2023 Google LLC\n",
    "#\n",
    "# Licensed under the Apache License, Version 2.0 (the \"License\");\n",
    "# you may not use this file except in compliance with the License.\n",
    "# You may obtain a copy of the License at\n",
    "#\n",
    "#      http://www.apache.org/licenses/LICENSE-2.0\n",
    "#\n",
    "# Unless required by applicable law or agreed to in writing, software\n",
    "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
    "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
    "# See the License for the specific language governing permissions and\n",
    "# limitations under the License."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "22de629b-581f-4335-9e7b-f73221d8dbcb",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "# ControlNet depth with StyleAligned over SDXL"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "486b7ebb-c483-4bf0-ace8-f8092c2d1f23",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "from diffusers import ControlNetModel, StableDiffusionXLControlNetPipeline, AutoencoderKL\n",
    "from diffusers.utils import load_image\n",
    "from transformers import DPTImageProcessor, DPTForDepthEstimation\n",
    "import torch\n",
    "import mediapy\n",
    "import sa_handler\n",
    "import pipeline_calls"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2a7e85e7-b5cf-45b2-946a-5ba1e4923586",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# Model setup: depth estimator, depth ControlNet, VAE and the SDXL ControlNet pipeline,\n",
    "# followed by registration of the StyleAligned handler on that pipeline.\n",
    "\n",
    "# DPT depth model (moved to the GPU) and its matching image preprocessor.\n",
    "depth_estimator = DPTForDepthEstimation.from_pretrained(\n",
    "    \"Intel/dpt-hybrid-midas\"\n",
    ").to(\"cuda\")\n",
    "feature_processor = DPTImageProcessor.from_pretrained(\"Intel/dpt-hybrid-midas\")\n",
    "\n",
    "# ControlNet, VAE and base pipeline are all loaded with float16 weights.\n",
    "controlnet = ControlNetModel.from_pretrained(\n",
    "    \"diffusers/controlnet-depth-sdxl-1.0\",\n",
    "    torch_dtype=torch.float16,\n",
    "    variant=\"fp16\",\n",
    "    use_safetensors=True,\n",
    ").to(\"cuda\")\n",
    "vae = AutoencoderKL.from_pretrained(\n",
    "    \"madebyollin/sdxl-vae-fp16-fix\",\n",
    "    torch_dtype=torch.float16,\n",
    ").to(\"cuda\")\n",
    "pipeline = StableDiffusionXLControlNetPipeline.from_pretrained(\n",
    "    \"stabilityai/stable-diffusion-xl-base-1.0\",\n",
    "    controlnet=controlnet,\n",
    "    vae=vae,\n",
    "    torch_dtype=torch.float16,\n",
    "    variant=\"fp16\",\n",
    "    use_safetensors=True,\n",
    ").to(\"cuda\")\n",
    "# NOTE(review): the pipeline is first moved to CUDA and then model CPU offload is enabled;\n",
    "# confirm that both steps are intended for the diffusers version in use.\n",
    "pipeline.enable_model_cpu_offload()\n",
    "\n",
    "# StyleAligned configuration: attention sharing on, with AdaIN applied to queries and keys only.\n",
    "sa_args = sa_handler.StyleAlignedArgs(\n",
    "    share_group_norm=False,\n",
    "    share_layer_norm=False,\n",
    "    share_attention=True,\n",
    "    adain_queries=True,\n",
    "    adain_keys=True,\n",
    "    adain_values=False,\n",
    ")\n",
    "handler = sa_handler.Handler(pipeline)\n",
    "handler.register(sa_args)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "94ca26b4-9061-4012-9400-8d97ef212d87",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# Prepare the two conditioning images passed to the ControlNet in the next cell.\n",
    "\n",
    "# First conditioning image: produced by pipeline_calls.get_depth_map from train.png.\n",
    "image = load_image(\"./example_image/train.png\")\n",
    "depth_image1 = pipeline_calls.get_depth_map(image, feature_processor, depth_estimator)\n",
    "\n",
    "# Second conditioning image: sun.png is loaded as-is and resized to 1024x1024\n",
    "# (presumably it already is a depth map -- confirm against the example assets).\n",
    "depth_image2 = load_image(\"./example_image/sun.png\")\n",
    "depth_image2 = depth_image2.resize((1024, 1024))\n",
    "\n",
    "mediapy.show_images([depth_image1, depth_image2])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c8f56fe4-559f-49ff-a2d8-460dcfeb56a0",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# run ControlNet depth with StyleAligned\n",
    "\n",
    "reference_prompt = \"a poster in flat design style\"\n",
    "target_prompts = [\"a train in flat design style\", \"the sun in flat design style\"]\n",
    "controlnet_conditioning_scale = 0.8\n",
    "num_images_per_prompt = 3  # adjust according to VRAM size\n",
    "seed = 0  # change to sample a different set of initial latents\n",
    "\n",
    "# Seeded generator so that Restart & Run All reproduces the same initial latents.\n",
    "latents_generator = torch.Generator().manual_seed(seed)\n",
    "# Per-sample latent shape (presumably 1024x1024 output with 8x VAE downsampling -- confirm).\n",
    "latent_shape = (4, 128, 128)\n",
    "latents_dtype = pipeline.unet.dtype\n",
    "\n",
    "# Row 0 is sampled once and reused for every target prompt; rows 1: are resampled per prompt.\n",
    "latents = torch.randn(\n",
    "    1 + num_images_per_prompt, *latent_shape, generator=latents_generator\n",
    ").to(latents_dtype)\n",
    "for depth_map, target_prompt in zip((depth_image1, depth_image2), target_prompts):\n",
    "    latents[1:] = torch.randn(\n",
    "        num_images_per_prompt, *latent_shape, generator=latents_generator\n",
    "    ).to(latents_dtype)\n",
    "    images = pipeline_calls.controlnet_call(\n",
    "        pipeline,\n",
    "        [reference_prompt, target_prompt],\n",
    "        image=depth_map,\n",
    "        num_inference_steps=50,\n",
    "        controlnet_conditioning_scale=controlnet_conditioning_scale,\n",
    "        num_images_per_prompt=num_images_per_prompt,\n",
    "        latents=latents,\n",
    "    )\n",
    "\n",
    "    # Show the first returned image, the conditioning depth map, then the remaining images.\n",
    "    result_titles = [f'result {i}' for i in range(1, len(images))]\n",
    "    mediapy.show_images(\n",
    "        [images[0], depth_map] + images[1:],\n",
    "        titles=[\"reference\", \"depth\"] + result_titles,\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "437ba4bd-6243-486b-8ba5-3b7cd661d53a",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}